{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/extras/solutions/01_pytorch_workflow_exercise_solutions.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N8LsPXZti9Sw"
      },
      "source": [
        "# 01. PyTorch Workflow Exercise Solutions\n",
        "\n",
        "The following is a solution notebook for the [PyTorch workflow exercises](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/exercises/01_pytorch_workflow_exercises.ipynb).\n",
        "\n",
        "Because PyTorch is flexible, there may be more than one way to answer each question.\n",
        "\n",
        "Don't worry about trying to be *right*; just try writing code that answers the question."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Show when last updated (for documentation purposes)\n",
        "import datetime\n",
        "print(f\"Last updated: {datetime.datetime.now()}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HU08b4hLjWfJ",
        "outputId": "a8ece4f3-7733-4094-dbac-fcc96787a136"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Last updated: 2023-01-19 01:04:46.979628\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Glu2fM4dkNlx"
      },
      "outputs": [],
      "source": [
        "# Import necessary libraries\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from torch import nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "LqKhXY26m31s",
        "outputId": "39ab2e2f-d4fc-4126-aee9-7c07829881a8"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'cuda'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "# Setup device-agnostic code\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "device"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g7HUhxCxjeBx"
      },
      "source": [
        "## 1. Create a straight-line dataset using the linear regression formula (`weight * X + bias`).\n",
        "  * Set `weight=0.3` and `bias=0.9`; there should be at least 100 datapoints in total.\n",
        "  * Split the data into 80% training, 20% testing.\n",
        "  * Plot the training and testing data to visualize them.\n",
        "\n",
        "The output of the cell below should look something like:\n",
        "```\n",
        "Number of X samples: 100\n",
        "Number of y samples: 100\n",
        "First 10 X & y samples:\n",
        "X: tensor([[0.0000],\n",
        "        [0.0100],\n",
        "        [0.0200],\n",
        "        [0.0300],\n",
        "        [0.0400],\n",
        "        [0.0500],\n",
        "        [0.0600],\n",
        "        [0.0700],\n",
        "        [0.0800],\n",
        "        [0.0900]])\n",
        "y: tensor([[0.9000],\n",
        "        [0.9030],\n",
        "        [0.9060],\n",
        "        [0.9090],\n",
        "        [0.9120],\n",
        "        [0.9150],\n",
        "        [0.9180],\n",
        "        [0.9210],\n",
        "        [0.9240],\n",
        "        [0.9270]])\n",
        "```\n",
        "\n",
        "Of course, the numbers in `X` and `y` may differ, but ideally they're created using the linear regression formula."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KbDG5MV7jhvE",
        "outputId": "9a81d5bc-8a9f-49d8-8196-b8906312159d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of X samples: 100\n",
            "Number of y samples: 100\n",
            "First 10 X & y samples:\n",
            "X: tensor([[0.0000],\n",
            "        [0.0100],\n",
            "        [0.0200],\n",
            "        [0.0300],\n",
            "        [0.0400],\n",
            "        [0.0500],\n",
            "        [0.0600],\n",
            "        [0.0700],\n",
            "        [0.0800],\n",
            "        [0.0900]])\n",
            "y: tensor([[0.9000],\n",
            "        [0.9030],\n",
            "        [0.9060],\n",
            "        [0.9090],\n",
            "        [0.9120],\n",
            "        [0.9150],\n",
            "        [0.9180],\n",
            "        [0.9210],\n",
            "        [0.9240],\n",
            "        [0.9270]])\n"
          ]
        }
      ],
      "source": [
        "# Create the data parameters\n",
        "weight = 0.3\n",
        "bias = 0.9\n",
        "# Make X and y using the linear regression formula\n",
        "X = torch.arange(0, 1, 0.01).unsqueeze(dim=1)\n",
        "y = weight * X + bias\n",
        "print(f\"Number of X samples: {len(X)}\")\n",
        "print(f\"Number of y samples: {len(y)}\")\n",
        "print(f\"First 10 X & y samples:\\nX: {X[:10]}\\ny: {y[:10]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GlwtT1djkmLw",
        "outputId": "b0fce3d0-1b57-4051-a935-c05d40cf2dc7"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(80, 80, 20, 20)"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# Split the data into training and testing\n",
        "train_split = int(len(X) * 0.8)\n",
        "X_train = X[:train_split]\n",
        "y_train = y[:train_split]\n",
        "X_test = X[train_split:]\n",
        "y_test = y[train_split:]\n",
        "len(X_train), len(y_train), len(X_test), len(y_test)"
      ]
    },
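    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the split above slices `X` in order rather than shuffling it, the test set only covers the last 20% of the input range. The following is a minimal, self-contained sketch that re-creates the data and checks the split sizes and this ordering (the checks are illustrative and not part of the exercise):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "weight, bias = 0.3, 0.9\n",
        "X = torch.arange(0, 1, 0.01).unsqueeze(dim=1)\n",
        "y = weight * X + bias\n",
        "\n",
        "train_split = int(len(X) * 0.8)\n",
        "X_train, y_train = X[:train_split], y[:train_split]\n",
        "X_test, y_test = X[train_split:], y[train_split:]\n",
        "\n",
        "# 80 training and 20 testing samples, each with a single feature\n",
        "print(X_train.shape, X_test.shape)\n",
        "\n",
        "# Sequential split: every test input is larger than every training input\n",
        "print(X_train.max().item() < X_test.min().item())\n",
        "```\n",
        "\n",
        "A random split (for example, indexing with `torch.randperm(len(X))`) would mix the two ranges; the sequential split used here instead tests how well the model extrapolates beyond the training inputs."
      ]
    },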
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 428
        },
        "id": "29iQZFNhlYJ-",
        "outputId": "72479e77-f478-4162-ec07-89ede1b72b10"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 720x504 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAGbCAYAAAD3MIVlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3df3RU9Z3/8debBITyI4KJooLAUUIBRX5EkLNa2OqqiBYr31bUYqxW4wG6+lV0W3VVoJ52LWrXlnbRFvnW9msRD3yPB7uidXERBUlQoAJq8Uc1mDYpth5gRSR5f/+YMU1Cfkwyc2funft8nDMnmXvvzHzI5ceL+37Pe8zdBQAAgMzqlusFAAAA5CNCFgAAQAAIWQAAAAEgZAEAAASAkAUAABCAwlwvoDXFxcU+dOjQXC8DAACgQ1u2bPmLu5e03B7KkDV06FBVVVXlehkAAAAdMrM/tradciEAAEAACFkAAAABIGQBAAAEgJAFAAAQAEIWAABAAEL57sL2NDQ0qLq6WgcOHMj1UhAB3bt317HHHqt+/frleikAgJiJXMj6y1/+IjPTiBEj1K0bF+LQNnfXJ598oj179kgSQQsAkFWRSyl/+9vfdNxxxxGw0CEz0xe+8AWdeOKJqq2tzfVyAAAxE7mkUl9fr+7du+d6GYiQXr166bPPPsv1MgAAMdNhyDKzZWZWa2avt7H/SjPbbma/N7OXzez0JvsuMLM3zWy3mX0nU4s2s0w9FWKA3y8AgFxI5UrWckkXtLP/XUlT3P00SYskPSxJZlYgaYmkaZJGSbrczEaltVoAAICI6DBkuft6SR+1s/9ld/9r8u4mSYOS30+UtNvd33H3Q5J+I2lGmusFAACIhEz3ZF0r6T+T358o6YMm+6qT21plZtebWZWZVdXV1WV4Wfnn6quv1kUXXdSpx0ydOlXz5s0LaEXtmzdvnqZOnZqT1wYAIBcyNsLBzP5RiZB1Vlce7+4PK1lqLCsr80ytK9c66gcqLy/X8uXLO/28//7v/y73zv2YVq1aFZk3Dbz33nsaNmyYKisrVVZWluvlAADQaRkJWWY2RtLPJU1z973JzXskDW5y2KDktlipqalp/H7NmjW67rrrmm3r1atXs+M/++yzlIJQUVFRp9cyYMCATj8GAAB0TdrlQjM7SdIqSbPd/a0muyolDTezYWbWQ9IsSU+l+3pRM3DgwMbb0Ucf3WzbwYMHdfTRR+vxxx/Xl7/8ZfXq1UtLly7V3r17dfnll2vQoEHq1auXRo8erUcffbTZ87YsF06dOlVz5szR7bffruLiYh177LGaP3++Ghoamh3TtFw4dOhQfe9731NFRYX69eunQYMG6Yc//GGz13nrrbc0ZcoU9ezZUyNGjNBvf/tb9enTp92rb/X19Zo/f7769++v/v3766abblJ9fX2zY5555hmdffbZ6t+/vwYMGKDzzz9fu3btatw/bNgwSdIZZ5whM2ssNVZWVuq8885TcXGx+vXrp7POOksbN25M4UwAAOJk7tNzVbiwUHOfnpuzNaQywuFxSRsljTCzajO71sxuMLMbkofcJekYST81s61mViVJ7n5Y0jxJayXtkvSEu+8I5FcRcd/97nc1Z84c7dy5U5dccokOHjyo8ePHa82aNdqxY4duvPFGVVRU6Pnnn2/3eX7961+rsLBQL7/8sn7yk5/oRz/6kVasWNHuYx588EGddtppevXVV/Uv//Ivuu222xpDS0NDg7761a+qsLBQmzZt0vLly7VgwQJ9+umn7T7n/fffr0ceeURLly7Vxo0bVV9fr1//+tfNjjlw4IBuuukmbd68WS+88IKKiop08cUX69ChQ5KkzZs3S0qEsZqaGq1atUqStG/fPs2ePVsvvviiNm/erLFjx+rCCy/U3r17BQDA55ZuWap6r9fSLUtztwh3D91twoQJ3padO3e2ua8z5sxxLy
hIfM2WlStXeuJHnvDuu++6JF+8eHGHj73sssv82muvbbxfXl7u06dPb7w/ZcoUP/PMM5s95txzz232mClTpvjcuXMb7w8ZMsRnzZrV7DGnnHKKL1q0yN3dn3nmGS8oKPDq6urG/S+99JJL8kcffbTNtR5//PH+ve99r/F+fX29Dx8+3KdMmdLmY/bv3+/dunXzF1980d3//rOprKxs8zHu7g0NDT5w4EB/7LHH2j0uU79vAADRMGfNHC9YUOBz1gT/D72kKm8lz0Ru4numLF0q1dcnvuZay8bu+vp63XvvvRozZoyOOeYY9enTR6tWrdL777/f7vOMGTOm2f0TTjihw4+Tae8xb7zxhk444QSdeOLf3xR6xhlntPuRRh9//LFqamo0efLkxm3dunXTpEmTmh339ttv64orrtDJJ5+sfv366bjjjlNDQ0OHv8ba2lpVVFSotLRURUVF6tu3r2prazt8HAAgf7VWGlwyfYkO33VYS6Yvydm6YhuyKiqkgoLE11zr3bt3s/uLFy/W/fffr1tvvVXPP/+8tm7dqksuuaSxlNaWlg3zZtasJytTj8mEiy66SHV1dVq6dKleeeUVvfbaayosLOzw11heXq7Kyko9+OCDevnll7V161YNGjSow8cBAPJXKEqDrYhtyFqyRDp8OPE1bDZs2KCLL75Ys2fP1tixY3XyySfrrbfe6viBGfbFL35RH374oT788MPGbVVVVe2GsKKiIh1//PHatGlT4zZ3b+yxkqS9e/fqjTfe0O23365zzz1XI0eO1L59+3T48OHGY3r06CFJRzTMb9iwQd/+9rc1ffp0jR49Wn379m32bk0AQPxUTKhQgRWoYkIIrpw0kbE5Wcic0tJSrVixQhs2bFBxcbF+/OMf691339W4ceOyuo5/+qd/0ogRI1ReXq7Fixfrk08+0c0336zCwsJ253/deOON+v73v6/S0lKddtpp+ulPf6qamhodf/zxkqT+/furuLhYjzzyiAYPHqw9e/bo1ltvVWHh3387HnvsserVq5fWrl2roUOHqmfPnioqKlJpaal+9atfadKkSTpw4IBuu+22xkAGAIinJdOX5LQs2JbYXskKszvvvFMTJ07UtGnT9KUvfUm9e/fWlVdemfV1dOvWTatXr9ann36qiRMnqry8XHfccYfMTD179mzzcbfccou++c1v6lvf+pYmTZqkhoaGZuvv1q2bVqxYoe3bt+vUU0/V3LlztWjRIh111FGNxxQWFuqhhx7Sz3/+c51wwgmaMSPxiUzLli3T/v37NWHCBM2aNUvXXHONhg4dGtjPAAAQLmEYzZAq805ODc+GsrIyr6qqanXfrl27NHLkyCyvCJ/btm2bxo4dq6qqKk2YMCHXy0kZv28AID8ULixUvderwAp0+K7DHT8gC8xsi7sf8fEkXMlCu1avXq1nn31W7777rtatW6err75ap59+usaPH5/rpQEAYiis/VetIWShXfv27dO8efM0atQoXXnllRo5cqTWrl3b4WcyAgCQrrCOZkgV5ULEAr9vACB6wlgabA3lQgAAEClRKg22hhEOAAAglMI6miFVXMkCAAA5F6XRDKkiZAEAgJwL60fjpIOQBQAAci7q/VetIWQBAICsivpohlQRsmJq8eLFfBwNACAn8rE02BpCVsDMrN3b1Vdf3eXnvueee3TqqadmbrEdMDM9+eSTWXs9AEB+ysfSYGsY4RCwmpqaxu/XrFmj6667rtm2Xr165WJZAADkTNRHM6SKK1kBGzhwYOPt6KOPPmLb+vXrNWHCBPXs2VPDhg3THXfcoUOHDjU+ftWqVRozZox69eqlAQMGaMqUKfrzn/+s5cuXa8GCBdqxY0fjVbHly5e3uY777rtPAwcOVJ8+fXTVVVdp//79zfZXVlbqvPPOU3Fxsfr166ezzjpLGzdubNz/eWnxa1/7msys8f7bb7+tGTNmaODAgerdu7fGjx+vNWvWZOaHBwCItHwcy9AZhK
wcWrt2ra688krNmzdPO3bs0LJly/Tkk0/q9ttvlyT96U9/0qxZs1ReXq5du3Zp/fr1mj17tiTpsssu0y233KIRI0aopqZGNTU1uuyyy1p9nSeeeEJ33nmnFixYoFdffVUjRozQAw880OyYffv2afbs2XrxxRe1efNmjR07VhdeeKH27t0rKRHCJOmRRx5RTU1N4/39+/dr2rRpeu6557Rt2zbNnDlTl156qd54441AfmYAgOiIS+9Vm9w9dLcJEyZ4W3bu3Nnmvs6Ys2aOFywo8Dlr5mTk+VKxcuVKT/zIE84++2xfuHBhs2NWr17tvXv39oaGBt+yZYtL8vfee6/V57v77rt99OjRHb7u5MmT/Vvf+lazbeecc44PGTKkzcc0NDT4wIED/bHHHmvcJslXrlzZ4etNmjTJFy1a1OFx2ZSp3zcAgNTl4t/aXJBU5a3kmdheyQpDut6yZYvuvfde9enTp/F2xRVX6MCBA/rTn/6k008/Xeeee65OPfVUzZw5Uz/72c9UV1fX6dfZtWuXJk+e3Gxby/u1tbWqqKhQaWmpioqK1LdvX9XW1ur9999v97kPHDig2267TaNGjVL//v3Vp08fVVVVdfg4AEB+ictYhs6IbcgKwzsbGhoadPfdd2vr1q2Nt+3bt+sPf/iDSkpKVFBQoGeffVbPPvusxowZo1/84hcaPny4tm3blvG1lJeXq7KyUg8++KBefvllbd26VYMGDWrWH9aa+fPna+XKlVq0aJH++7//W1u3btXEiRM7fBwAIL+E4eJF2MT23YVheGfD+PHj9cYbb+iUU05p8xgz0+TJkzV58mTdddddGj16tFasWKHTTz9dPXr0UH19fYevM3LkSG3atEnXXHNN47ZNmzY1O2bDhg166KGHNH36dEnSn//852bvgpSk7t27H/F6GzZs0FVXXaWZM2dKkg4ePKi3335bpaWlHa4LAJA/KiZUaOmWpXk/lqEzYhuywuCuu+7SRRddpCFDhujrX/+6CgsL9frrr2vz5s267777tGnTJv3ud7/T+eefr+OOO06vvfaaPvjgA40aNUpS4h1/f/zjH/Xqq6/qpJNOUt++fXXUUUcd8To33nijrrrqKp1xxhmaOnWqnnzySb3yyisaMGBA4zGlpaX61a9+pUmTJjWWAHv06NHseYYOHarnn39eU6ZM0VFHHaX+/furtLRUq1ev1owZM9S9e3ctWLBABw8eDPYHBwAInTBcvAib2JYLw+D888/X008/rXXr1mnixImaOHGifvCDH+ikk06SJBUVFemll17SRRddpOHDh+uWW27Rv/7rv+ob3/iGJGnmzJm68MILdc4556ikpESPP/54q69z2WWX6Z577tEdd9yhcePG6fe//71uvvnmZscsW7ZM+/fv14QJEzRr1ixdc801R0yEv//++7Vu3ToNHjxY48aNkyQ98MADOvbYY3X22Wdr2rRpOvPMM3X22Wdn+CcFAAiTuI9mSJUlmuLDpayszKuqqlrdt2vXLo0cOTLLK0LU8fsGADKncGGh6r1eBVagw3cdzvVycs7Mtrh7WcvtXMkCAACdEoY3j0UBIQsAALSJ0QxdR8gCAABtYjRD1xGyAABAmygNdl0kRzi4u8ws18tARDQ0NOR6CQAQWYxm6LrIXcnq2bOn9u7dqzC+KxLh4u46dOiQ9uzZo969e+d6OQAQeoxmyKzIjXD47LPPVF1dzcBLpKSwsFBFRUUqLi5Wt26R+z8FAGQVoxm6pq0RDpErF3bv3l3Dhg3L9TIAAMg7fDROZvFfewAAYojRDMEjZAEAEEOMZggeIQsAgBhiNEPwItf4DgAAECZ8diEAADHFaIbcIGQBAJDn6L/KDUIWAAB5jv6r3CBkAQCQRxjNEB6ELAAA8gilwfAgZAEAkEcoDYYHIQsAgAhq6x2DlAbDg5AFAEAEURYMP0IWAAARRFkw/Jj4DgAAkIYuT3w3s2VmVmtmr7ex/4tmttHMPjWz+S32vW
dmvzezrWZGagIAoAuY2B5NqZQLl0u6oJ39H0n6Z0mL29j/j+4+trWEBwAAOkb/VTR1GLLcfb0SQaqt/bXuXinps0wuDAAAJNB/FU1BN767pGfNbIuZXd/egWZ2vZlVmVlVXV1dwMsCACCcmNieP4IOWWe5+3hJ0yTNNbMvtXWguz/s7mXuXlZSUhLwsgAACCdKg/kj0JDl7nuSX2slrZY0McjXAwAg6igN5o/CoJ7YzHpL6ubu+5LfnydpYVCvBwBAPlgyfQllwTyRygiHxyVtlDTCzKrN7Fozu8HMbkjuH2hm1ZJulnRn8ph+ko6TtMHMtknaLOlpd38muF8KAADRwmiG/MYwUgAAcqRwYaHqvV4FVqDDdx3O9XLQRV0eRgoAAIJB/1V+I2QBAJAFjGaIH0IWAABZwGiG+CFkAQCQBZQG44fGdwAAgDTQ+A4AQJYwmgESIQsAgIyj/woSIQsAgIyj/woSIQsAgLTMnSsVFia+fo7RDJAIWQAApGXpUqm+PvEVaIqQBQBAGioqpIKCxFegKUY4AAAApIERDgAApKm1/iugLYQsAABSRP8VOoOQBQBAiui/QmcQsgAAaKGtsuCSJdLhw4mvQEcIWQAAtEBZEJlAyAIAoAXKgsgERjgAAACkgREOAAC0grEMCAohCwAQa/RfISiELABArNF/haAQsgAAsdFaaZCxDAgKIQsAEBuUBpFNhCwAQGxQGkQ2McIBAAAgDYxwAADECqMZkGuELABAXqL/CrlGyAIA5CX6r5BrhCwAQOQxmgFhRMgCAEQepUGEESELABB5lAYRRoxwAAAASAMjHAAAeYHRDIgKQhYAIFLov0JUELIAAJFC/xWigpAFAAgtRjMgyghZAIDQojSIKCNkAQBCi9IgoowRDgAAAGlghAMAINQYzYB8Q8gCAIQC/VfIN4QsAEAo0H+FfEPIAgBkVVtlQUYzIN8QsgAAWUVZEHFByAIAZBVlQcQFIxwAAADSwAgHAEDWMZYBcUbIAgAEhv4rxFmHIcvMlplZrZm93sb+L5rZRjP71Mzmt9h3gZm9aWa7zew7mVo0ACAa6L9CnKVyJWu5pAva2f+RpH+WtLjpRjMrkLRE0jRJoyRdbmajurZMAEDYtVYaZCwD4qzDkOXu65UIUm3tr3X3Skmftdg1UdJud3/H3Q9J+o2kGeksFgAQXpQGgeaC7Mk6UdIHTe5XJ7cBAPIQpUGgudA0vpvZ9WZWZWZVdXV1uV4OAKCTKA0CzQUZsvZIGtzk/qDktla5+8PuXubuZSUlJQEuCwCQLkYzAB0LMmRVShpuZsPMrIekWZKeCvD1AABZQv8V0LFURjg8LmmjpBFmVm1m15rZDWZ2Q3L/QDOrlnSzpDuTx/Rz98OS5klaK2mXpCfcfUdwvxQAQLbQfwV0jI/VAQC0a+7cxBWrigr6rYDW8LE6AIAuoTQIdA0hCwDQLkqDQNdQLgQAAEgD5UIAQIcYzQBkDiELANCI/isgcwhZAIBG9F8BmUPIAoCYaq00yEfjAJlDyAKAmKI0CASLkAUAMUVpEAgWIxwAAADSwAgHAIgxRjMA2UfIAoAYoP8KyD5CFgDEAP1XQPYRsgAgzzCaAQgHQhYA5BlKg0A4ELIAIM9QGgTCgREOAAAAaWCEAwDkGcYyAOFGyAKAiKL3Cgg3QhYARBS9V0C4EbIAIAIYywBEDyELACKA0iAQPYQsAIgASoNA9DDCAQAAIA2McACAiGA0A5AfCFkAEDL0XwH5gZAFACFD/xWQHwhZAJBDjGYA8hchCwByiNIgkL8IWQCQQ5QGgfzFCAcAAIA0MMIBAHKM0QxAvBCyACBL6L8C4oWQBQBZQv8VEC+ELAAIAKMZABCyACAAlAYBELIAIACUBgEwwgEAACANjHAAgIAwmgFAawhZAJAm+q8AtIaQBQBpov8KQGsIWQDQCYxmAJAqQhYAdAKlQQCpImQBQCdQGgSQKkY4AAAApIERDgDQCYxlAJAuQh
YAtILeKwDpImQBQCvovQKQLkIWgNhjLAOAIBCyAMQepUEAQegwZJnZMjOrNbPX29hvZvaQme02s+1mNr7Jvnoz25q8PZXJhQNAplAaBBCEwhSOWS7pJ5J+2cb+aZKGJ2+TJP0s+VWSPnH3sWmuEQACtWQJZUEAmdfhlSx3Xy/po3YOmSHpl56wSdLRZnZ8phYIAJnEaAYA2ZKJnqwTJX3Q5H51cpsk9TSzKjPbZGaXtPckZnZ98tiqurq6DCwLAI5E/xWAbAm68X1IcgLqFZJ+ZGYnt3Wguz/s7mXuXlZSUhLwsgDEFf1XALIlEyFrj6TBTe4PSm6Tu3/+9R1JL0gal4HXA4CUMJoBQC5lImQ9Jemq5LsMz5T0sbvXmFl/MztKksysWNI/SNqZgdcDgJRQGgSQSx2+u9DMHpc0VVKxmVVLultSd0ly9/+Q9FtJF0raLel/JH0z+dCRkpaaWYMSYe4H7k7IApA1FRWJgEVpEEAumLvneg1HKCsr86qqqlwvAwAAoENmtiXZg94ME98B5AVGMwAIG0IWgLxA/xWAsCFkAcgLjGYAEDaELACRw2gGAFFAyAIQOZQGAUQBIQtA5FAaBBAFjHAAAABIAyMcAEQSoxkARBUhC0Co0X8FIKoIWQBCjf4rAFFFyAIQGoxmAJBPCFkAQoPSIIB8QsgCEBqUBgHkE0Y4AAAApIERDgBCg7EMAOKAkAUg6+i9AhAHhCwAWUfvFYA4IGQBCBRjGQDEFSELQKAoDQKIK0IWgEBRGgQQV4xwAAAASAMjHAAEjtEMAPB3hCwAGUP/FQD8HSELQMbQfwUAf0fIAtAljGYAgPYRsgB0CaVBAGgfIQtAl1AaBID2McIBAAAgDYxwANBljGYAgM4jZAHoEP1XANB5hCwAHaL/CgA6j5AFoBlGMwBAZhCyADRDaRAAMoOQBaAZSoMAkBmMcAAAAEgDIxwAHIHRDAAQHEIWEGP0XwFAcAhZQIzRfwUAwSFkATHBaAYAyC5CFhATlAYBILsIWUBMUBoEgOxihAMAAEAaGOEAxARjGQAgHAhZQJ6h9woAwoGQBeQZeq8AIBwIWUCEMZYBAMKLkAVEGKVBAAgvQhYQYZQGASC8GOEAAACQhrRGOJjZMjOrNbPX29hvZvaQme02s+1mNr7JvnIz+0PyVt71XwIQb4xmAIBoSbVcuFzSBe3snyZpePJ2vaSfSZKZDZB0t6RJkiZKutvM+nd1sUCc0X8FANGSUshy9/WSPmrnkBmSfukJmyQdbWbHSzpf0nPu/pG7/1XSc2o/rAFoA/1XABAtmWp8P1HSB03uVye3tbX9CGZ2vZlVmVlVXV1dhpYFRBOjGQAg+kLz7kJ3f9jdy9y9rKSkJNfLAXKK0iAARF+mQtYeSYOb3B+U3NbWdgDtoDQIANGXqZD1lKSrku8yPFPSx+5eI2mtpPPMrH+y4f285DYA7aA0CADRl+oIh8clbZQ0wsyqzexaM7vBzG5IHvJbSe9I2i3pEUlzJMndP5K0SFJl8rYwuQ1AEqMZACA/MYwUyLHCwkT/VUFB4uoVACBa0hpGCiA49F8BQH4iZAFZxGgGAIgPQhaQRYxmAID4IGQBWURpEADig8Z3AACANND4DmQZoxkAIN4IWUBA6L8CgHgjZAEBof8KAOKNkAVkAKMZAAAtEbKADKA0CABoiZAFZAClQQBAS4xwAAAASAMjHIAMYCwDACBVhCygE+i9AgCkipAFdAK9VwCAVBGygDYwlgEAkA5CFtAGSoMAgHQQsoA2UBoEAKSDEQ4AAABpYIQD0A5GMwAAMo2QBYj+KwBA5hGyANF/BQDIPEIWYofRDACAbCBkIXYoDQIAsoGQhdihNAgAyAZGOAAAAKSBEQ6IJUYzAAByhZCFvEb/FQAgVwhZyGv0XwEAcoWQhbzBaAYAQJgQspA3KA0CAMKEkIW8QWkQABAmjHAAAABIAyMckFcYzQAACDtCFiKJ/isAQNgRshBJ9F8BAM
KOkIXQYzQDACCKCFkIPUqDAIAoImQh9CgNAgCiiBEOAAAAaWCEA0KPsQwAgHxCyEJo0HsFAMgnhCyEBr1XAIB8QshCTjCWAQCQ7whZyAlKgwCAfEfIQk5QGgQA5DtGOAAAAKSBEQ7IGUYzAADiiJCFwNF/BQCIo5RClpldYGZvmtluM/tOK/uHmNnzZrbdzF4ws0FN9tWb2dbk7alMLh7RQP8VACCOOgxZZlYgaYmkaZJGSbrczEa1OGyxpF+6+xhJCyV9v8m+T9x9bPL2lQytGyHFaAYAABJSuZI1UdJud3/H3Q9J+o2kGS2OGSXpv5Lfr2tlP2KC0iAAAAmphKwTJX3Q5H51cltT2yRdmvz+q5L6mtkxyfs9zazKzDaZ2SVtvYiZXZ88rqquri7F5SNsKA0CAJCQqcb3+ZKmmNlrkqZI2iOpPrlvSPJtjVdI+pGZndzaE7j7w+5e5u5lJSUlGVoWso3SIAAACamErD2SBje5Pyi5rZG7f+jul7r7OEl3JLf9Lfl1T/LrO5JekDQu/WUjDBjNAABA21IJWZWShpvZMDPrIWmWpGbvEjSzYjP7/Lm+K2lZcnt/Mzvq82Mk/YOknZlaPHKL/isAANrWYchy98OS5klaK2mXpCfcfYeZLTSzz98tOFXSm2b2lqTjJN2b3D5SUpWZbVOiIf4H7k7IyhP0XwEA0DY+VgcpmTs3ccWqooJ+KwAAmuJjdZAWSoMAAHQOIQspoTQIAEDnUC4EAABIA+VCpIzRDAAApI+QhSPQfwUAQPoIWTgC/VcAAKSPkBVzrZUG+WgcAADSR8iKOUqDAAAEg5AVc5QGAQAIBiMcAAAA0sAIh5hjLAMAANlFyIoJeq8AAMguQlZM0HsFAEB2EbLyEGMZAADIPUJWHqI0CABA7hGy8hClQQAAco8RDgAAAGlghEOeYjQDAADhRMiKOPqvAAAIJ0JWxNF/BQBAOBGyIoTRDAAARAchK0IoDQIAEB2ErAihNAgAQHQwwgEAACANjHCIGEYzAAAQbYSskKL/CgCAaCNkhRT9VwAARBshKwQYzQAAQP4hZIUApUEAAPIPISsEKA0CAJB/GOEAAACQBkY4hASjGQAAiAdCVpbRfwUAQDwQsrKM/isAAOKBkBUgRlSCCRMAAAauSURBVDMAABBfhKwAURoEACC+CFkBojQIAEB8McIBAAAgDYxwCBBjGQAAQEuErAyg9woAALREyMoAeq8AAEBLhKxOYiwDAABIBSGrkygNAgCAVBCyOonSIAAASAUjHAAAANLACIcuYDQDAADoKkJWO+i/AgAAXUXIagf9VwAAoKsIWUmMZgAAAJmUUsgyswvM7E0z221m32ll/xAze97MtpvZC2Y2qMm+cjP7Q/JWnsnFZxKlQQAAkEkdhiwzK5C0RNI0SaMkXW5mo1octljSL919jKSFkr6ffOwASXdLmiRpoqS7zax/5pafOZQGAQBAJqVyJWuipN3u/o67H5L0G0kzWhwzStJ/Jb9f12T/+ZKec/eP3P2vkp6TdEH6y848SoMAACCTUglZJ0r6oMn96uS2prZJujT5/Vcl9TWzY1J8bNYxmgEAAAQtU43v8yVNMbPXJE2RtEdSfWeewMyuN7MqM6uqq6vL0LJaR/8VAAAIWioha4+kwU3uD0pua+TuH7r7pe4+TtIdyW1/S+WxTZ7jYXcvc/eykpKSTvwSOo/+KwAAELQOP1bHzAolvSXpHCUCUqWkK9x9R5NjiiV95O4NZnavpHp3vyvZ+L5F0vjkoa9KmuDuH7X3mnysDgAAiIouf6yOux+WNE/SWkm7JD3h7jvMbKGZfSV52FRJb5rZW5KOk3Rv8rEfSVqkRDCrlLSwo4AFAACQD/iAaAAAgDTwAdEAAABZRMgCAAAIACELAAAgAIQsAACAABCyAAAAAkDIAgAACAAhCwAAIACELAAAgAAQsgAAAAJAyAIAAAgAIQsAACAAhCwAAIAAhPIDos2sTtIfA36ZYk
l/Cfg10Hmcl/Di3IQT5yW8ODfhFMR5GeLuJS03hjJkZYOZVbX2idnILc5LeHFuwonzEl6cm3DK5nmhXAgAABAAQhYAAEAA4hyyHs71AtAqzkt4cW7CifMSXpybcMraeYltTxYAAECQ4nwlCwAAIDCELAAAgADkfcgyswvM7E0z221m32ll/1FmtiK5/xUzG5r9VcZPCuflZjPbaWbbzex5MxuSi3XGUUfnpslxM83MzYy3qGdBKufFzL6e/HOzw8z+b7bXGFcp/H12kpmtM7PXkn+nXZiLdcaNmS0zs1oze72N/WZmDyXP23YzG5/pNeR1yDKzAklLJE2TNErS5WY2qsVh10r6q7ufIulBSf+W3VXGT4rn5TVJZe4+RtKTku7L7irjKcVzIzPrK+lGSa9kd4XxlMp5MbPhkr4r6R/cfbSkm7K+0BhK8c/MnZKecPdxkmZJ+ml2VxlbyyVd0M7+aZKGJ2/XS/pZpheQ1yFL0kRJu939HXc/JOk3kma0OGaGpP+T/P5JSeeYmWVxjXHU4Xlx93Xu/j/Ju5skDcryGuMqlT8zkrRIif+QHMzm4mIslfNynaQl7v5XSXL32iyvMa5SOTcuqV/y+yJJH2ZxfbHl7uslfdTOITMk/dITNkk62syOz+Qa8j1knSjpgyb3q5PbWj3G3Q9L+ljSMVlZXXylcl6aulbSfwa6Inyuw3OTvKQ+2N2fzubCYi6VPzOlkkrN7CUz22Rm7f0PHpmTyrm5R9I3zKxa0m8lfTs7S0MHOvtvUacVZvLJgEwzs29IKpM0JddrgWRm3SQ9IOnqHC8FRypUouwxVYkrv+vN7DR3/1tOVwVJulzScne/38wmS3rMzE5194ZcLwzByvcrWXskDW5yf1ByW6vHmFmhEpdy92ZldfGVynmRmZ0r6Q5JX3H3T7O0trjr6Nz0lXSqpBfM7D1JZ0p6iub3wKXyZ6Za0lPu/pm7vyvpLSVCF4KVyrm5VtITkuTuGyX1VOJDipFbKf1blI58D1mVkoab2TAz66FEw+FTLY55SlJ58vv/Jem/nAmtQevwvJjZOElLlQhY9JZkT7vnxt0/dvdidx/q7kOV6Jf7irtX5Wa5sZHK32X/T4mrWDKzYiXKh+9kc5Exlcq5eV/SOZJkZiOVCFl1WV0lWvOUpKuS7zI8U9LH7l6TyRfI63Khux82s3mS1koqkLTM3XeY2UJJVe7+lKRfKHHpdrcSDXKzcrfieEjxvPxQUh9JK5PvQ3jf3b+Ss0XHRIrnBlmW4nlZK+k8M9spqV7Sre7OVfmApXhubpH0iJn9byWa4K/mP/PBM7PHlfiPR3GyH+5uSd0lyd3/Q4n+uAsl7Zb0P5K+mfE1cJ4BAAAyL9/LhQAAADlByAIAAAgAIQsAACAAhCwAAIAAELIAAAACQMgCAAAIACELAAAgAP8fbA9G3qopc1MAAAAASUVORK5CYII=\n"
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ],
      "source": [
        "# Plot the training and testing data\n",
        "def plot_predictions(train_data=X_train,\n",
        "                     train_labels=y_train,\n",
        "                     test_data=X_test,\n",
        "                     test_labels=y_test,\n",
        "                     predictions=None):\n",
        "  \"\"\"Plot training data, test data and (optionally) predictions on the test data. Inputs must be on the CPU.\"\"\"\n",
        "  plt.figure(figsize=(10, 7))\n",
        "  plt.scatter(train_data, train_labels, c=\"b\", s=4, label=\"Training data\")\n",
        "  plt.scatter(test_data, test_labels, c=\"g\", s=4, label=\"Test data\")\n",
        "\n",
        "  if predictions is not None:\n",
        "    plt.scatter(test_data, predictions, c=\"r\", s=4, label=\"Predictions\")\n",
        "  plt.legend(prop={\"size\": 14})\n",
        "\n",
        "plot_predictions()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ImZoe3v8jif8"
      },
      "source": [
        "## 2. Build a PyTorch model by subclassing `nn.Module`.\n",
        "  * Inside, there should be two randomly initialized `nn.Parameter()`s with `requires_grad=True`: one for `weights` and one for `bias`.\n",
        "  * Implement the `forward()` method to compute the linear regression function you used to create the dataset in exercise 1.\n",
        "  * Once you've constructed the model, make an instance of it and check its `state_dict()`.\n",
        "  * **Note:** If you'd like to use `nn.Linear()` instead of `nn.Parameter()` you can."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qzd__Y5rjtB8",
        "outputId": "0c61680e-76ab-465a-e081-7bc8063b4a79"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(LinearRegressionModel(),\n",
              " OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))]))"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ],
      "source": [
        "# Create PyTorch linear regression model by subclassing nn.Module\n",
        "## Option 1: create the weight and bias with nn.Parameter()\n",
        "class LinearRegressionModel(nn.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.weight = nn.Parameter(data=torch.randn(1,\n",
        "                                                requires_grad=True,\n",
        "                                                dtype=torch.float))\n",
        "    self.bias = nn.Parameter(data=torch.randn(1,\n",
        "                                              requires_grad=True,\n",
        "                                              dtype=torch.float))\n",
        "\n",
        "  def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "    return self.weight * x + self.bias\n",
        "\n",
        "# ## Option 2: let nn.Linear() create and store the parameters\n",
        "# class LinearRegressionModel(nn.Module):\n",
        "#   def __init__(self):\n",
        "#     super().__init__()\n",
        "#     self.linear_layer = nn.Linear(in_features=1,\n",
        "#                                   out_features=1)\n",
        "#\n",
        "#   def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "#     return self.linear_layer(x)\n",
        "\n",
        "torch.manual_seed(42)\n",
        "model_1 = LinearRegressionModel()\n",
        "model_1, model_1.state_dict()"
      ]
    },
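    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you pick Option 2 instead, `nn.Linear(in_features=1, out_features=1)` creates the parameters for you. Its weight has shape `(out_features, in_features)`, i.e. `(1, 1)` rather than `(1,)`, and the `state_dict()` keys are prefixed with the attribute name. A minimal sketch to inspect this (the class name `LinearRegressionModelV2` is hypothetical):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "class LinearRegressionModelV2(nn.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    # nn.Linear stores weight with shape (out_features, in_features) and bias with shape (out_features,)\n",
        "    self.linear_layer = nn.Linear(in_features=1, out_features=1)\n",
        "\n",
        "  def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "    return self.linear_layer(x)\n",
        "\n",
        "torch.manual_seed(42)\n",
        "model_v2 = LinearRegressionModelV2()\n",
        "state = model_v2.state_dict()\n",
        "for name, param in state.items():\n",
        "  print(name, tuple(param.shape))\n",
        "\n",
        "# Parameters created by nn.Linear are tracked by autograd by default\n",
        "print(all(p.requires_grad for p in model_v2.parameters()))\n",
        "```\n",
        "\n",
        "The `forward()` output has the same shape as Option 1's: one prediction per input row."
      ]
    },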
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "G-yjYL_Rl92t",
        "outputId": "90d2330a-efe1-4b83-ae58-9814adeb76b9"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "device(type='cpu')"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "# Check which device the model's parameters are on\n",
        "next(model_1.parameters()).device"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5LdcDnmOmyQ2",
        "outputId": "b801eef6-e0a5-4667-a4db-0e841c72d1a9"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Parameter containing:\n",
              " tensor([0.3367], device='cuda:0', requires_grad=True), Parameter containing:\n",
              " tensor([0.1288], device='cuda:0', requires_grad=True)]"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "# Move the model to the target device\n",
        "model_1.to(device)\n",
        "list(model_1.parameters())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G6nYOrJhjtfu"
      },
      "source": [
        "## 3. Create a loss function and optimizer using `nn.L1Loss()` and `torch.optim.SGD(params, lr)` respectively.\n",
        "  * Set the optimizer's learning rate to 0.01; the parameters to optimize should be those of the model you created in exercise 2.\n",
        "  * Write a training loop to perform the appropriate training steps for 300 epochs.\n",
        "  * The training loop should test the model on the test dataset every 20 epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "ltvoZ-FWjv1j"
      },
      "outputs": [],
      "source": [
        "# Create the loss function and optimizer\n",
        "loss_fn = nn.L1Loss()\n",
        "optimizer = torch.optim.SGD(params=model_1.parameters(),\n",
        "                            lr=0.01)"
      ]
    },
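    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With its default `reduction=\"mean\"`, `nn.L1Loss()` computes the mean absolute error (MAE) between predictions and targets. A small sketch comparing it with the formula written out by hand (the tensors are made-up values for illustration):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "preds = torch.tensor([[1.0], [2.0], [3.0]])\n",
        "targets = torch.tensor([[1.5], [2.0], [2.0]])\n",
        "\n",
        "loss_fn = nn.L1Loss()\n",
        "builtin = loss_fn(preds, targets)\n",
        "\n",
        "# Mean absolute error by hand: mean(|preds - targets|)\n",
        "by_hand = (preds - targets).abs().mean()\n",
        "print(builtin.item(), by_hand.item())\n",
        "```\n",
        "\n",
        "Here the absolute differences are 0.5, 0.0 and 1.0, so both values equal their mean, 0.5."
      ]
    },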
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xpE83NvNnkdV",
        "outputId": "460841c8-8f4e-41ec-8645-6b3a062db28d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 0 | Train loss: 0.757 | Test loss: 0.725\n",
            "Epoch: 20 | Train loss: 0.525 | Test loss: 0.454\n",
            "Epoch: 40 | Train loss: 0.294 | Test loss: 0.183\n",
            "Epoch: 60 | Train loss: 0.077 | Test loss: 0.073\n",
            "Epoch: 80 | Train loss: 0.053 | Test loss: 0.116\n",
            "Epoch: 100 | Train loss: 0.046 | Test loss: 0.105\n",
            "Epoch: 120 | Train loss: 0.039 | Test loss: 0.089\n",
            "Epoch: 140 | Train loss: 0.032 | Test loss: 0.074\n",
            "Epoch: 160 | Train loss: 0.025 | Test loss: 0.058\n",
            "Epoch: 180 | Train loss: 0.018 | Test loss: 0.042\n",
            "Epoch: 200 | Train loss: 0.011 | Test loss: 0.026\n",
            "Epoch: 220 | Train loss: 0.004 | Test loss: 0.009\n",
            "Epoch: 240 | Train loss: 0.004 | Test loss: 0.006\n",
            "Epoch: 260 | Train loss: 0.004 | Test loss: 0.006\n",
            "Epoch: 280 | Train loss: 0.004 | Test loss: 0.006\n"
          ]
        }
      ],
      "source": [
        "# Training loop\n",
        "# Train model for 300 epochs\n",
        "torch.manual_seed(42)\n",
        "\n",
        "epochs = 300\n",
        "\n",
        "# Send data to target device\n",
        "X_train = X_train.to(device)\n",
        "X_test = X_test.to(device)\n",
        "y_train = y_train.to(device)\n",
        "y_test = y_test.to(device)\n",
        "\n",
        "for epoch in range(epochs):\n",
        "  ### Training\n",
        "\n",
        "  # Put model in train mode\n",
        "  model_1.train()\n",
        "\n",
        "  # 1. Forward pass\n",
        "  y_pred = model_1(X_train)\n",
        "\n",
        "  # 2. Calculate loss\n",
        "  loss = loss_fn(y_pred, y_train)\n",
        "\n",
        "  # 3. Zero gradients\n",
        "  optimizer.zero_grad()\n",
        "\n",
        "  # 4. Backpropagation\n",
        "  loss.backward()\n",
        "\n",
        "  # 5. Step the optimizer\n",
        "  optimizer.step()\n",
        "\n",
        "  ### Testing (every 20 epochs)\n",
        "  if epoch % 20 == 0:\n",
        "    # Put the model in evaluation mode and set up the inference context\n",
        "    model_1.eval()\n",
        "    with torch.inference_mode():\n",
        "      # 1. Forward pass\n",
        "      y_preds = model_1(X_test)\n",
        "      # 2. Calculate test loss\n",
        "      test_loss = loss_fn(y_preds, y_test)\n",
        "      # Print out what's happening\n",
        "      print(f\"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}\")"
      ]
    },
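    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since the data was generated with `weight=0.3` and `bias=0.9`, a quick sanity check is to compare the learned parameters with those values. The following is a minimal, self-contained sketch that repeats the training above on the CPU with the same seed, loss function, optimizer and number of epochs; the tolerance of `0.05` is an arbitrary choice for illustration, and the exact numbers may differ slightly from the GPU run:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "weight, bias = 0.3, 0.9\n",
        "X = torch.arange(0, 1, 0.01).unsqueeze(dim=1)\n",
        "y = weight * X + bias\n",
        "X_train, y_train = X[:80], y[:80]\n",
        "\n",
        "class LinearRegressionModel(nn.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    # nn.Parameter sets requires_grad=True by default\n",
        "    self.weight = nn.Parameter(torch.randn(1))\n",
        "    self.bias = nn.Parameter(torch.randn(1))\n",
        "\n",
        "  def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "    return self.weight * x + self.bias\n",
        "\n",
        "torch.manual_seed(42)\n",
        "model = LinearRegressionModel()\n",
        "loss_fn = nn.L1Loss()\n",
        "optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)\n",
        "\n",
        "for _ in range(300):\n",
        "  model.train()\n",
        "  loss = loss_fn(model(X_train), y_train)\n",
        "  optimizer.zero_grad()\n",
        "  loss.backward()\n",
        "  optimizer.step()\n",
        "\n",
        "learned_weight = model.weight.item()\n",
        "learned_bias = model.bias.item()\n",
        "print(f\"learned weight: {learned_weight:.3f} (true: {weight})\")\n",
        "print(f\"learned bias: {learned_bias:.3f} (true: {bias})\")\n",
        "\n",
        "# Both values should land close to the ones used to create the data\n",
        "print(abs(learned_weight - weight) < 0.05 and abs(learned_bias - bias) < 0.05)\n",
        "```"
      ]
    },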
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x4j4TM18jwa7"
      },
      "source": [
        "## 4. Make predictions with the trained model on the test data.\n",
        "  * Visualize these predictions against the original training and testing data (**note:** you may need to make sure the predictions are *not* on the GPU if you want to use non-CUDA-enabled libraries such as matplotlib to plot)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bbMPK5Qjjyx_",
        "outputId": "c93b598c-59d7-41a6-af42-4d74605acf57"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[1.1464],\n",
              "        [1.1495],\n",
              "        [1.1525],\n",
              "        [1.1556],\n",
              "        [1.1587],\n",
              "        [1.1617],\n",
              "        [1.1648],\n",
              "        [1.1679],\n",
              "        [1.1709],\n",
              "        [1.1740],\n",
              "        [1.1771],\n",
              "        [1.1801],\n",
              "        [1.1832],\n",
              "        [1.1863],\n",
              "        [1.1893],\n",
              "        [1.1924],\n",
              "        [1.1955],\n",
              "        [1.1985],\n",
              "        [1.2016],\n",
              "        [1.2047]], device='cuda:0')"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ],
      "source": [
        "# Make predictions with the model\n",
        "model_1.eval()\n",
        "\n",
        "with torch.inference_mode():\n",
        "  y_preds = model_1(X_test)\n",
        "y_preds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EUxlAFiCstd5",
        "outputId": "79d88bb0-01da-41d5-f2e4-1f7dc222af4d"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[1.1464],\n",
              "        [1.1495],\n",
              "        [1.1525],\n",
              "        [1.1556],\n",
              "        [1.1587],\n",
              "        [1.1617],\n",
              "        [1.1648],\n",
              "        [1.1679],\n",
              "        [1.1709],\n",
              "        [1.1740],\n",
              "        [1.1771],\n",
              "        [1.1801],\n",
              "        [1.1832],\n",
              "        [1.1863],\n",
              "        [1.1893],\n",
              "        [1.1924],\n",
              "        [1.1955],\n",
              "        [1.1985],\n",
              "        [1.2016],\n",
              "        [1.2047]])"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ],
      "source": [
        "y_preds.cpu()"
      ]
    },
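    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "matplotlib plots NumPy arrays, and a tensor can only be converted to NumPy while it is in CPU memory (calling `.numpy()` on a CUDA tensor raises an error). The minimal sketch below uses a hypothetical stand-in tensor `preds` in place of `y_preds`, so it runs on its own with or without a GPU:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# Use the GPU if one is available, otherwise fall back to the CPU\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Hypothetical stand-in for y_preds: 20 predictions with shape [20, 1]\n",
        "preds = torch.linspace(1.1464, 1.2047, steps=20, device=device).unsqueeze(dim=1)\n",
        "\n",
        "# .cpu() copies the tensor to CPU memory (it returns the same tensor if it is already there)\n",
        "preds_cpu = preds.cpu()\n",
        "\n",
        "# Only CPU tensors can be turned into the NumPy arrays matplotlib works with\n",
        "preds_np = preds_cpu.numpy()\n",
        "print(preds.device, preds_cpu.device, preds_np.shape)\n",
        "```\n",
        "\n",
        "This is why the plotting cell passes `y_preds.cpu()` rather than `y_preds`."
      ]
    },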
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 428
        },
        "id": "K3BdmQaDpFo8",
        "outputId": "f96069ea-217d-4c03-855e-df4c80162336"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 720x504 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAGbCAYAAAD3MIVlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXSU9fn//9dFwiaERQgiRIGCIIuokKLUBaqoICpWvy1bNfRjSzxAq79KrVUPUdAPbcVS21o/SFVqq1ah0MOhKvbDBxcqaMJa2RTFBYwQl1rEopJcvz9mTJMwSSbM3LM+H+fMSea+75l5kzvoi/u65hpzdwEAACC+miV7AQAAAJmIkAUAABAAQhYAAEAACFkAAAABIGQBAAAEIDfZC4ikc+fO3rNnz2QvAwAAoFHr169/393z625PyZDVs2dPlZWVJXsZAAAAjTKztyJtp1wIAAAQAEIWAABAAAhZAAAAASBkAQAABICQBQAAEICUfHdhQ6qqqrRnzx4dPHgw2UtBGmjevLm6dOmidu3aJXspAIAsk3Yh6/3335eZqV+/fmrWjAtxqJ+769///rf27t0rSQQtAEBCpV1K+ec//6njjjuOgIVGmZmOOeYYde/eXfv370/2cgAAWSbtkkplZaWaN2+e7GUgjbRu3VpffPFFspcBAMgyaReypNAVCiBa/L4AAJIhLUMWAABAqiNkAQAABICQlaamTJmiSy65pEmPGTlypGbMmBHQiho2Y8YMjRw5MimvDQBAMqTdCId001g/UFFRkRYtWtTk573nnnvk7k16zNKlS9PmTQNvvvmmevXqpdLSUhUWFiZ7OQAANBkhK2Dl5eXV369YsULf+973am1r3bp1reO/+OKLqIJQ+/btm7yWY489tsmPAQAAR4dyYcC6du1afevQoUOtbYcOHVKHDh302GOP6bzzzlPr1q21YMECffDBB5o4caIKCgrUunVrDRw4UA899FCt561bLhw5cqSmTZumm2++WZ07d1aXLl00c+ZMVVVV1TqmZrmwZ8+euuOOO1RcXKx27dqpoKBAd911V63XefXVVzVixAi1atVK/fr105NPPqm2bds2ePWtsrJSM2fOVMeOHdWxY0ddf/31qqysrHXM008/rXPOOUcdO3bUscceq4suukjbt2+v3t+rVy9J0le/+lWZWXWpsbS0VBdeeKE6d+6sdu3a6eyzz9batWujOBMAgKwyfbqUmxv6miSErBTwk5/8RNOmTdO2bdt0+eWX69ChQxoyZIhWrFihrVu36rrrrlNxcbFWrVrV4PM88sgjys3N1Ysvvqjf/OY3+uUvf6nHH3+8wcfMnz9fp5xyijZs2KAf//jHuvHGG6tDS1VVlb7xjW8oNzdX69at06JFi3T77bfrs88+a/A57777bi1cuFALFizQ2rVrVVlZqUceeaTWMQcPHtT111+vl19+Wc8++6zat2+vSy+9VJ9//rkk6eWXX5YUCmPl5eVaunSpJOnAgQO66qqr9MILL+jll1/WaaedposvvlgffPBBg2sCAGSZBQukysrQ12Rx95S7DR061Ouzbdu2evc1xbRp7jk5oa+JsnjxYg/9yEN2797tknzevHmNPnb8+PF+zTXXVN8vKirysWPHVt8fMWKEn3nmmbUeM2rUqFqPGTFihE+fPr36fo8ePXzChAm1HtOnTx+fM2eOu7s//fTTnpOT43v27Kne//e//90l+UMPPVTvWo8//ni/4447qu9XVlb6SSed5CNGjKj3MZ988ok3a9bMX3jhBXf/z8+mtLS03se4u1dVVXnXrl39D3/4Q4PHxev3BgCQJhL4P3pJZR4hzzR6JcvMHjSz/Wb2Sj37J5vZFjP7h5m9aGan1tg32sx2mtkuM7spTrkwLlIh4H6pbmN3ZWWl7rzzTg0ePFidOnVS27ZttXTpUr399tsNPs/gwYNr3e/WrV
ujHyfT0GN27Nihbt26qXv37tX7v/rVrzb4kUYff/yxysvLNXz48OptzZo10xlnnFHruNdff12TJk1S79691a5dOx133HGqqqpq9M+4f/9+FRcXq2/fvmrfvr3y8vK0f//+Rh8HAMhgkUqD994rHT4c+pok0ZQLF0ka3cD+3ZJGuPspkuZIul+SzCxH0r2SxkgaIGmimQ2IabVxVFws5eSEviZbmzZtat2fN2+e7r77bv3oRz/SqlWrtGnTJl1++eXVpbT61G2YN7NaPVnxekw8XHLJJaqoqNCCBQv00ksvaePGjcrNzW30z1hUVKTS0lLNnz9fL774ojZt2qSCgoJGHwcAyGCpdOWkhkZDlrs/L+nDBva/6O4fhe+uk1QQ/n6YpF3u/oa7fy7pT5LGxbjeuEmBgFuvNWvW6NJLL9VVV12l0047Tb1799arr76a8HWcfPLJevfdd/Xuu+9WbysrK2swhLVv317HH3+81q1bV73N3at7rCTpgw8+0I4dO3TzzTdr1KhR6t+/vw4cOKDDhw9XH9OiRQtJOqJhfs2aNfr+97+vsWPHauDAgcrLy6v1bk0AQBZKpSsnNcS78f0aSU+Fv+8u6Z0a+/aEt0VkZlPNrMzMyioqKuK8rPTSt29frVq1SmvWrNGOHTs0Y8YM7d69O+HruOCCC9SvXz8VFRVp8+bNWrdunX74wx8qNze3wflf1113nX7+859ryZIl2rlzp66//vpaQahjx47q3LmzFi5cqF27dum5557Ttddeq9zc/0wU6dKli1q3bq2VK1dq3759+vjjjyWFfjZ//OMftW3bNpWWlmrChAnVgQwAkKVS9MpJ3EKWmX1doZD146N5vLvf7+6F7l6Yn58fr2WlpVtvvVXDhg3TmDFjdO6556pNmzaaPHlywtfRrFkzLVu2TJ999pmGDRumoqIi3XLLLTIztWrVqt7H3XDDDfrOd76j7373uzrjjDNUVVVVa/3NmjXT448/ri1btmjQoEGaPn265syZo5YtW1Yfk5ubq1/96lf63e9+p27dumncuNBF0AcffFCffPKJhg4dqgkTJui//uu/1LNnz8B+BgCAFJMCoxmiZR7F1HAz6ylphbsPqmf/YEnLJI1x91fD24ZLus3dLwrf/4kkufvcxl6vsLDQy8rKIu7bvn27+vfv3+iaEYzNmzfrtNNOU1lZmYYOHZrs5USN3xsAyBC5uaH+q5yc0NWrekz/63QtWL9AxUOLde/YYK9wmdl6dz/i40livpJlZidKWirpqi8DVlippJPMrJeZtZA0QdLyWF8PibVs2TI988wz2r17t1avXq0pU6bo1FNP1ZAhQ5K9NABANoqy/2rB+gWq9EotWJ+8ZvhoRjg8JmmtpH5mtsfMrjGza83s2vAhsyR1kvRbM9tkZmWS5O6HJc2QtFLSdklPuPvWQP4UCMyBAwc0Y8YMDRgwQJMnT1b//v21cuXKRj+TEQCAmMUwmqF4aLFyLEfFQ5PXDB9VuTDRKBci3vi9AYA0FGVpMNkCKxcCAAAEIsrS4PS/Tlfu7FxN/2tqNcMTsgAAQGqKsjSYCv1XkRCyAABA8sUwmiEV+q8iIWQBAIDki/KjcSKVBu8de68Ozzoc+KiGpiJkAQCA5Euj0QzRImQBAIDESvPRDNEiZGWpefPm8XE0AIDkiLI0GEmqlgYjIWQFzMwavE2ZMuWon/u2227ToEERP+koEGamJUuWJOz1AAAZKs1HM0SLkBWw8vLy6tvChQuP2HbPPfckeYUAAASkvncMpvlohmgRsgLWtWvX6luHDh2O2Pb8889r6NChatWqlXr16qVbbrlFn3/+efXjly5dqsGDB6t169Y69thjNWLECO3bt0+LFi3S7bffrq1bt1ZfFVu0aFG96/j5z3+url27qm3btrr66qv1ySef1NpfWlqqCy+8UJ07d1a7du109tlna+3atdX7vywtfvOb35SZVd9//fXXNW7cOHXt2lVt2rTRkCFDtG
LFivj88AAA6S2GsqCUXv1XkRCykmjlypWaPHmyZsyYoa1bt+rBBx/UkiVLdPPNN0uS3nvvPU2YMEFFRUXavn27nn/+eV111VWSpPHjx+uGG25Qv379qq+KjR8/PuLrPPHEE7r11lt1++23a8OGDerXr59+8Ytf1DrmwIEDuuqqq/TCCy/o5Zdf1mmnnaaLL75YH3zwgaRQCJOkhQsXqry8vPr+J598ojFjxuhvf/ubNm/erCuvvFJXXHGFduzYEcjPDACQRqIsC0rpNZohau6ecrehQ4d6fbZt21bvvqaYtmKa59ye49NWTIvL80Vj8eLFHvqRh5xzzjk+e/bsWscsW7bM27Rp41VVVb5+/XqX5G+++WbE5yspKfGBAwc2+rrDhw/37373u7W2nX/++d6jR496H1NVVeVdu3b1P/zhD9XbJPnixYsbfb0zzjjD58yZ0+hxiRSv3xsAQDBybs9x3SbPuT0n2UtpMkllHiHPZO2VrFSo865fv1533nmn2rZtW32bNGmSDh48qPfee0+nnnqqRo0apUGDBunKK6/Ufffdp4qKiia/zvbt2zV8+PBa2+re379/v4qLi9W3b1+1b99eeXl52r9/v95+++0Gn/vgwYO68cYbNWDAAHXs2FFt27ZVWVlZo48DAGSYGCa2S+lfGowka0NWKpzMqqoqlZSUaNOmTdW3LVu26LXXXlN+fr5ycnL0zDPP6JlnntHgwYP1wAMP6KSTTtLmzZvjvpaioiKVlpZq/vz5evHFF7Vp0yYVFBTU6g+LZObMmVq8eLHmzJmj5557Tps2bdKwYcMafRwAIMM0of8qI0uDEWRtyEqFkzlkyBDt2LFDffr0OeKWm5srKTQ2Yfjw4SopKVFpaam6deumxx9/XJLUokULVVZWNvo6/fv317p162ptq3t/zZo1+v73v6+xY8dq4MCBysvLU3l5ea1jmjdvfsTrrVmzRldffbWuvPJKDR48WAUFBXr99deb/LMAAKS5JvRfpUI1KRGyNmSlglmzZunRRx/VrFmz9Morr2jHjh1asmSJbrzxRkmhIHTHHXeotLRUb7/9tpYvX6533nlHAwYMkBR6x99bb72lDRs26P3339dnn30W8XWuu+46/f73v9fChQv12muvae7cuXrppZdqHdO3b1/98Y9/1LZt21RaWqoJEyaoRYsWtY7p2bOnVq1apffee08fffRR9eOWLVumDRs26B//+Ie+/e1v69ChQ/H+UQEAUkkME9ul1KgmJUSkRq1k3xLR+J4MdRvf3d1XrlzpZ599trdu3drz8vJ86NCh/utf/9rdQ3/W0aNHe5cuXbxFixbeu3dv/9nPflb92EOHDvmVV17pHTp0cEn+0EMP1fva//3f/+35+fnepk0bnzhxopeUlNRqfN+0aZMPGzbMW7Vq5V/5ylf84Ycf9oEDB3pJSUn1McuXL/c+ffp4bm5u9WPffPNNP//88/2YY47x7t27+1133eVjx471oqKiWH9ccZXOvzcAkHJyctyl0FfU2/huoX2ppbCw0MvKyiLu2759u/r375/gFSHd8XsDAHE0fXqo96q4uMErV9P/Ol0L1i9Q8dDijOq1qsvM1rt7Yd3tlAsBAEDTZMnE9lgRsgAAQP1iGM2QNb1X9SBkAQCA+kU5miFbxjI0BSELAADUL8rRDNleGoyEkAUAAEJiGM2Q7aXBSHh3IbICvzcAEIXc3FBpMCcnFKwQFd5dCAAAGhZlaTBS/xWORMgCAAAhjGaIK0IWAADZiNEMgSNkZZglS5bIzKrvL1q0SG3bto3pOZ999lmZmd5///1YlwcASBWMZggcIStBpkyZIjOTmal58+b6yle+opkzZ+rgwYOBvu748eP1xhtvRH18z549NW/evFrbvva1r6m8vFydOnWK9/IAAMnCaIbAEbISaNSoUSovL9cbb7yhO+64Q7/97W81c+bMI447fPiw4vWuz9atW6tLly4xPUeLFi3UtWvXWlfIAABphN
EMSUHISqCWLVuqa9euOuGEEzRp0iRNnjxZf/nLX3Tbbbdp0KBBWrRokXr37q2WLVvq4MGD+vjjjzV16lR16dJFeXl5GjFihOqOtnj44YfVo0cPHXPMMbrkkku0b9++WvsjlQuffPJJnXHGGWrdurU6deqkSy+9VIcOHdLIkSP11ltv6Uc/+lH1VTcpcrlw6dKlOuWUU9SyZUudcMIJuvPOO2sFw549e+qOO+5QcXGx2rVrp4KCAt1111211rFgwQL17dtXrVq1UufOnXXRRRfpMG8ZBoD4ozSYFISsJGrdurW++OILSdLu3bv16KOPavHixdq8ebNatmypsWPHau/evVqxYoU2btyoc889V+edd57Ky8slSS+99JKmTJmiqVOnatOmTbr00ks1a9asBl/z6aef1mWXXaYLLrhA69ev1+rVqzVixAhVVVVp6dKlKigo0KxZs1ReXl79OnWtX79e3/zmN3XFFVfoH//4h376059q7ty5+s1vflPruPnz5+uUU07Rhg0b9OMf/1g33nij1q5dK0kqKyvT9OnTVVJSop07d2rVqlUaPXp0rD9SAEAklAaTw91T7jZ06FCvz7Zt2+rd1yTTprnn5IS+JkBRUZGPHTu2+v5LL73knTp18m9961teUlLiubm5/t5771XvX7Vqlbdp08Y//fTTWs9z6qmn+s9+9jN3d584caKPGjWq1v5rrrnGQ6c15KGHHvI2bdpU3//a177m48ePr3edPXr08LvuuqvWttWrV7skr6iocHf3SZMm+de//vVax5SUlHj37t1rPc+ECRNqHdOnTx+fM2eOu7v/+c9/9nbt2vm//vWvetcST3H7vQGADDZtxTTPuT3Hp61IzP8bM4WkMo+QZ7L3SlaUl07j6emnn1bbtm3VqlUrDR8+XOeee65+/etfS5IKCgp03HHHVR+7fv16ffrpp8rPz1fbtm2rb6+88opef/11SaEp5sOHD6/1GnXv17Vx40adf/75Mf05tm/frrPOOqvWtrPPPlt79+7Vv/71r+ptgwcPrnVMt27dtH//fknSBRdcoB49eqhXr16aPHmyfv/73+vAgQMxrQsAoJhGM1AajK/sDVlRXjqNp3PPPVebNm3Szp07dejQIS1durS6Kb1Nmza1jq2qqtJxxx2nTZs21brt2LFDc+bMSdiam6pmc3zz5s2P2FdVVSVJysvL04YNG/TEE0/oxBNP1Ny5c3XyySfr3XffTeh6ASDjxNB/hfjK3pAV5bsq4umYY45Rnz591KNHjyMCSF1DhgzRvn371KxZM/Xp06fW7ctg1r9/f61bt67W4+rer+v000/XqlWr6t3fokULVVZWNvgc/fv319///vda29asWaOCggLl5eU1+NiacnNzdd5552nu3LnasmWLDh48qBUrVkT9eABABPRfpYzsDVkpbtSoUTrrrLM0btw4PfXUU9q9e7fWrl2rkpISvfDCC5KkH/zgB/rf//1fzZ07V6+99poWLlyoZcuWNfi8t9xyixYvXqxbb71V27Zt09atWzV//nx9+umnkkLvCnzhhRe0d+/eeoeP3nDDDXruued022236dVXX9Ujjzyiu+++WzfeeGPUf74VK1bonnvu0caNG/XWW2/p0Ucf1YEDB/gQZwCIVn1lQUYzpAxCVooyMz355JM677zz9L3vfU/9+vXTt771Le3cuVPdunWTJJ155pl64IEHdN9992nw4MFaunSpbrvttgaf9+KLL9ayZcv01FNP6fTTT9eIESO0evVqNWsW+lWYPXu23nnnHfXu3Vv5+fkRn2PIkCFavHix/vznP2vQoEG66aabdNNNN2nGjBlR//k6dOigv/zlLxo1apROPvlkzZs3T7/73e90zjnnRP0cAJDVmtBbzGiG5DCP09DLeCosLPS686C+tH37dq52oMn4vQGQcaZPDwWs4uJGr1rlzs5VpVcqx3J0eBbzCOPNzNa7e2Hd7VzJAgAgHTWht5jSYHIQsgAASH
UxjGWQKA0mCyELAIBUF2P/FZKDkAUAQKprwmxHRjOkjrQMWanYrI/U9eUAVABIC5FKg/RfpaW0e3fh7t27lZeXp06dOtWaLg7U5e764osvtG/fPrm7TjzxxGQvCQAal5sbKg3m5ISCVT2m/3W6FqxfoOKhxfRaJVl97y7MTcZiYlFQUKA9e/aooqIi2UtBGsjNzVX79u3VuXPnZC8FAKJTXPyf0QwNqFkWJGSlprQLWc2bN1evXr2SvQwAAIJx771RlwW/vJKF1JSWPVkAAGSEKEczMLE9PTXak2VmD0q6RNJ+dx8UYf/Jkh6SNETSLe4+r8a+NyUdkFQp6XCkemUkDfVkAQCQMaLsv2Jie2qLZeL7IkmjG9j/oaQfSJpXz/6vu/tp0QYsAACyRpSjGXjHYHpqNGS5+/MKBan69u9391JJX8RzYQAAZJQYRjNQGkxPQfdkuaRnzGy9mU1t6EAzm2pmZWZWxjsHAQAZJ8qp7UxszxxBh6yz3X2IpDGSppvZufUd6O73u3uhuxfm5+cHvCwAABIsytIgE9szR6Ahy933hr/ul7RM0rAgXw8AgJQVZWmQ/qvMEVjIMrM2Zpb35feSLpT0SlCvBwBAymA0AxTdCIfHJI2U1FnSPkklkppLkrv/j5l1lVQmqZ2kKkmfSBoQPn5Z+GlyJT3q7ndGsyhGOAAA0hqjGbLKUX+sjrtPbGT/e5IKIuz6l6RTo14hAACZIsqPxmFqe2ZLuw+IBgAASCWxDCMFAAAxYjRD9iFkAQCQAIxmyD6ELAAAEoDRDNmHkAUAQJwxmgESIQsAgLijNAiJkAUAQNxRGoREyAIAICaRhrtTGoREyAIAICYLFoSGuy+gMog6CFkAAMSguDj06TmNDHdHFmLiOwAAQAyY+A4AQIwi9V8B9SFkAQAQJfqv0BSELAAAokT/FZqCkAUAQB31lQXvvVc6fDj0FWgMIQsAgDooCyIeCFkAANRBWRDxwAgHAACAGDDCAQCACBjLgKAQsgAAWY3+KwSFkAUAyGr0XyEohCwAQNaIVBpkLAOCQsgCAGQNSoNIJEIWACBrUBpEIhGyAAAZidIgko2QBQDISJQGkWyELABARqI0iGRj4jsAAEAMmPgOAMhYTG1HKiJkAQDSHv1XSEWELABA2qP/CqmIkAUASCuMZkC6IGQBANIKpUGkC0IWACCtUBpEumCEAwAAQAwY4QAASDuMZkA6I2QBAFIW/VdIZ4QsAEDKov8K6YyQBQBICYxmQKYhZAEAUgKlQWQaQhYAICVQGkSmIWQBABKqvncMUhpEpiFkAQASirIgsgUhCwCQUJQFkS2Y+A4AABADJr4DABKOie3IZoQsAEBg6L9CNms0ZJnZg2a238xeqWf/yWa21sw+M7OZdfaNNrOdZrbLzG6K16IBAOmB/itks2iuZC2SNLqB/R9K+oGkeTU3mlmOpHsljZE0QNJEMxtwdMsEAKQ6JrYDtTUastz9eYWCVH3797t7qaQv6uwaJmmXu7/h7p9L+pOkcbEsFgCQuigNArUF2ZPVXdI7Ne7vCW+LyMymmlmZmZVVVFQEuCwAQBAoDQK1pUzju7vf7+6F7l6Yn5+f7OUAAJqI0iBQW5Aha6+kE2rcLwhvAwCkOUYzAI0LMmSVSjrJzHqZWQtJEyQtD/D1AAAJQv8V0LhoRjg8JmmtpH5mtsfMrjGza83s2vD+rma2R9IPJd0aPqadux+WNEPSSknbJT3h7luD+6MAABKF/iugcXysDgCgQdOnh65YFRfTbwVEwsfqAACOCqVB4OgQsgAADaI0CBwdyoUAAAAxoFwIAGgUoxmA+CFkAQCq0X8FxA8hCwBQjf4rIH4IWQCQpSKVBvloHCB+CFkAkKUoDQLBImQBQJaiNAgEi5AFAFmA0iCQeIQsAMgClAaBxCNkAUAWoDQIJB4T3wEAAGLAxHcAyBJMbQdSAyELADIM/VdAaiBkAUCGof8KSA2ELABIU/WVBRnNAKQGQh
YApCnKgkBqI2QBQJqiLAikNkY4AAAAxIARDgCQxhjLAKQfQhYApAH6r4D0Q8gCgDRA/xWQfghZAJBiIpUGGcsApB9CFgCkGEqDQGYgZAFAiqE0CGQGQhYAJBGlQSBzEbIAIIkoDQKZi5AFAElEaRDIXEx8BwAAiAET3wEgyZjaDmQXQhYAJAj9V0B2IWQBQILQfwVkF0IWAASA0QwACFkAEABKgwAIWQAQAEqDABjhAAAAEANGOABAQBjNACASQhYAxIj+KwCRELIAIEb0XwGIhJAFAE3AaAYA0SJkAUATUBoEEC1CFgA0AaVBANEiZAFABPW9Y5DSIIBoEbIAIALKggBiRcgCgAgoCwKIFRPfAQAAYsDEdwCoBxPbAQSh0ZBlZg+a2X4ze6We/WZmvzKzXWa2xcyG1NhXaWabwrfl8Vw4AMQL/VcAghDNlaxFkkY3sH+MpJPCt6mS7qux79/uflr4dtlRrxIAAkT/FYAgNBqy3P15SR82cMg4SQ97yDpJHczs+HgtEADiiYntABIlHj1Z3SW9U+P+nvA2SWplZmVmts7MLm/oScxsavjYsoqKijgsCwCORGkQQKIE3fjeI9xtP0nSL82sd30Huvv97l7o7oX5+fkBLwtAtqI0CCBR4hGy9ko6ocb9gvA2ufuXX9+Q9Kyk0+PwegBw1CgNAkiUeISs5ZKuDr/L8ExJH7t7uZl1NLOWkmRmnSWdJWlbHF4PAKLCaAYAydToMFIze0zSSEmdJe2TVCKpuSS5+/+YmUn6jULvQPxU0nfcvczMviZpgaQqhcLcL939gWgWxTBSAPGQmxvqv8rJCV29AoAg1DeMNLexB7r7xEb2u6Qj/p3o7i9KOqUpiwSAeCouDjW4038FIBmY+A4gIzCaAUCqIWQByAiMZgCQaghZADICoxkApBpCFoC0Q2kQQDogZAFIO5QGAaQDQhaAtENpEEA6aHROVjIwJwsAAKSL+uZkcSULQEpjajuAdEXIApDS6L8CkK4IWQBSGv1XANIVIQtAymA0A4BMQsgCkDIoDQLIJIQsACmD0iCATMIIBwAAgBgwwgFAymAsA4BsQMgCkHD0XgHIBoQsAAlH7xWAbEDIAhAoxjIAyFaELACBojQIIFsRsgAEitIggO6+HP8AAA2iSURBVGxFyAIQN5QGAeA/CFkA4obSIAD8ByELQNxQGgSA/2DiOwAAQAyY+A4grpjaDgANI2QBOCr0XwFAwwhZAI4K/VcA0DBCFoBGMZoBAJqOkAWgUZQGAaDpCFkAGkVpEACajhEOAAAAMWCEA4CoMJoBAOKDkAWgFvqvACA+CFkAaqH/CgDig5AFZDFGMwBAcAhZQBajNAgAwSFkAVmM0iAABIeQBWQJSoMAkFiELCBLUBoEgMQiZAFZgtIgACQWE98BAABiwMR3IEswsR0AUgMhC8gw9F4BQGogZAEZht4rAEgNhCwgjTGWAQBSFyELSGOUBgEgdRGygDRGaRAAUhcjHAAAAGIQ0wgHM3vQzPab2Sv17Dcz+5WZ7TKzLWY2pMa+IjN7LXwrOvo/ApDdGM0AAOkl2nLhIkmjG9g/RtJJ4dtUSfdJkpkdK6lE0hmShkkqMbOOR7tYIJvRfwUA6SWqkOXuz0v6sIFDxkl62EPWSepgZsdLukjS39z9Q3f/SNLf1HBYA1AP+q8AIL3Eq/G9u6R3atzfE95W3/YjmNlUMyszs7KKioo4LQtIT4xmAID0lzLvLnT3+9290N0L8/Pzk70cIKkoDQJA+otXyNor6YQa9wvC2+rbDqABlAYBIP3FK2Qtl3R1+F2GZ0r62N3LJa2UdKGZdQw3vF8Y3gagAZQGASD9RTvC4TFJayX1M7M9ZnaNmV1rZteGD3lS0huSdklaKGmaJLn7h5LmSCoN32aHtwEIYzQDAGQmhpECSZabG+q/yskJXb0CAKSXmIaRAggO/VcAkJkIWUACMZoBALIHIQtIIEYzAED2IGQBCURpEACyByELCA
ilQQDIboQsICCUBgEguxGygIBQGgSA7MacLAAAgBgwJwsIEFPbAQB1EbKAOKD/CgBQFyELiAP6rwAAdRGygCaoryzIaAYAQF2ELKAJKAsCAKJFyAKagLIgACBajHAAAACIASMcgCZiLAMAIBaELKAe9F8BAGJByALqQf8VACAWhCxAkUuDjGUAAMSCkAWI0iAAIP4IWYAoDQIA4o+QhaxDaRAAkAiELGQdSoMAgEQgZCHrUBoEACQCE98BAABiwMR3ZCWmtgMAkoWQhYxG/xUAIFkIWcho9F8BAJKFkIWMwWgGAEAqIWQhY1AaBACkEkIWMgalQQBAKmGEAwAAQAwY4YCMwmgGAECqI2QhLdF/BQBIdYQspCX6rwAAqY6QhZTHaAYAQDoiZCHlURoEAKQjQhZSHqVBAEA6ImQhZdT3jkFKgwCAdETIQsqgLAgAyCSELKQMyoIAgEzCxHcAAIAYMPEdKYWJ7QCATEfIQlLQfwUAyHSELCQF/VcAgExHyELgmNgOAMhGhCwEjtIgACAbEbIQOEqDAIBsFFXIMrPRZrbTzHaZ2U0R9vcws1VmtsXMnjWzghr7Ks1sU/i2PJ6LR3qgNAgAyEaNhiwzy5F0r6QxkgZImmhmA+ocNk/Sw+4+WNJsSXNr7Pu3u58Wvl0Wp3UjRTGaAQCAkGiuZA2TtMvd33D3zyX9SdK4OscMkPR/4e9XR9iPLEH/FQAAIdGErO6S3qlxf094W02bJV0R/v4bkvLMrFP4fiszKzOzdWZ2eX0vYmZTw8eVVVRURLl8pBr6rwAACIlX4/tMSSPMbKOkEZL2SqoM7+sRHjU/SdIvzax3pCdw9/vdvdDdC/Pz8+O0LASJ0QwAANQvmpC1V9IJNe4XhLdVc/d33f0Kdz9d0i3hbf8Mf90b/vqGpGclnR77spEKKA0CAFC/aEJWqaSTzKyXmbWQNEFSrXcJmllnM/vyuX4i6cHw9o5m1vLLYySdJWlbvBaP5KI0CABA/RoNWe5+WNIMSSslbZf0hLtvNbPZZvbluwVHStppZq9KOk7SneHt/SWVmdlmhRrif+ruhKw0RGkQAICmMXdP9hqOUFhY6GVlZcleBmrIzQ2VBnNyQsEKAACEmNn6cP95LUx8R1QoDQIA0DRcyQIAAIgBV7IQNaa2AwAQO0IWjsBoBgAAYkfIwhHovwIAIHaErCzHaAYAAIJByMpylAYBAAgGISvLURoEACAYjHAAAACIASMcshxjGQAASCxCVpag9woAgMQiZGUJeq8AAEgsQlYGYiwDAADJR8jKQJQGAQBIPkJWBqI0CABA8hGy0hylQQAAUhMhK81RGgQAIDURstIcpUEAAFITE98BAABiwMT3DMDUdgAA0gchK43QfwUAQPogZKUR+q8AAEgfhKwUxWgGAADSGyErRVEaBAAgvRGyUhSlQQAA0hsjHAAAAGLACIcUxmgGAAAyDyErBdB/BQBA5iFkpQD6rwAAyDyErARjNAMAANmBkJVglAYBAMgOhKwEozQIAEB2IGQFiNIgAADZi5AVIEqDAABkL0JWgCgNAgCQvZj4DgAAEAMmvgeIie0AAKAuQlYc0HsFAADqImTFAb1XAACgLkJWEzGWAQAARIOQ1USUBgEAQDQIWU1EaRAAAESDEQ4AAAAxYITDUWA0AwAAOFqErAbQfwUAAI4WIasB9F8BAICjRcgKYzQDAACIp6hClpmNNrOdZrbLzG6KsL+Hma0ysy1m9qyZFdTYV2Rmr4VvRfFcfDxRGgQAAPHUaMgysxxJ90oaI2mApIlmNqDOYfMkPezugyXNljQ3/NhjJZVIOkPSMEklZtYxfsuPH0qDAAAgnqK5kjVM0i53f8PdP5f0J0nj6hwzQNL/hb9fXWP/RZL+5u4fuvtHkv4maXTsy44/SoMAACCeoglZ3SW9U+P+nvC2mjZLuiL8/Tck5ZlZpygfm3CMZgAAAEGLV+P7TEkjzGyjpBGS9kqqbMoTmNlUMyszs7
KKioo4LSsy+q8AAEDQoglZeyWdUON+QXhbNXd/192vcPfTJd0S3vbPaB5b4znud/dCdy/Mz89vwh+h6ei/AgAAQWv0Y3XMLFfSq5LOVygglUqa5O5baxzTWdKH7l5lZndKqnT3WeHG9/WShoQP3SBpqLt/2NBr8rE6AAAgXRz1x+q4+2FJMyStlLRd0hPuvtXMZpvZZeHDRkraaWavSjpO0p3hx34oaY5CwaxU0uzGAhYAAEAm4AOiAQAAYsAHRAMAACQQIQsAACAAhCwAAIAAELIAAAACQMgCAAAIACELAAAgAIQsAACAABCyAAAAAkDIAgAACAAhCwAAIACELAAAgAAQsgAAAAKQkh8QbWYVkt4K+GU6S3o/4NdA03FeUhfnJjVxXlIX5yY1BXFeerh7ft2NKRmyEsHMyiJ9YjaSi/OSujg3qYnzkro4N6kpkeeFciEAAEAACFkAAAAByOaQdX+yF4CIOC+pi3OTmjgvqYtzk5oSdl6yticLAAAgSNl8JQsAACAwhCwAAIAAZHzIMrPRZrbTzHaZ2U0R9rc0s8fD+18ys56JX2X2ieK8/NDMtpnZFjNbZWY9krHObNTYualx3JVm5mbGW9QTIJrzYmbfCv+92WpmjyZ6jdkqiv+enWhmq81sY/i/aRcnY53ZxsweNLP9ZvZKPfvNzH4VPm9bzGxIvNeQ0SHLzHIk3StpjKQBkiaa2YA6h10j6SN37yNpvqSfJXaV2SfK87JRUqG7D5a0RNLPE7vK7BTluZGZ5Um6TtJLiV1hdormvJjZSZJ+Iuksdx8o6fqELzQLRfl35lZJT7j76ZImSPptYleZtRZJGt3A/jGSTgrfpkq6L94LyOiQJWmYpF3u/oa7fy7pT5LG1TlmnKTfh79fIul8M7MErjEbNXpe3H21u38avrtOUkGC15itovk7I0lzFPoHyaFELi6LRXNevifpXnf/SJLcfX+C15itojk3Lqld+Pv2kt5N4Pqylrs/L+nDBg4ZJ+lhD1knqYOZHR/PNWR6yOou6Z0a9/eEt0U8xt0PS/pYUqeErC57RXNearpG0lOBrghfavTchC+pn+Duf03kwrJcNH9n+krqa2Z/N7N1ZtbQv+ARP9Gcm9skfdvM9kh6UtL3E7M0NKKp/y9qstx4PhkQb2b2bUmFkkYkey2QzKyZpF9ImpLkpeBIuQqVPUYqdOX3eTM7xd3/mdRVQZImSlrk7neb2XBJfzCzQe5eleyFIViZfiVrr6QTatwvCG+LeIyZ5Sp0KfeDhKwue0VzXmRmoyTdIukyd/8sQWvLdo2dmzxJgyQ9a2ZvSjpT0nKa3wMXzd+ZPZKWu/sX7r5b0qsKhS4EK5pzc42kJyTJ3ddKaqXQhxQjuaL6f1EsMj1klUo6ycx6mVkLhRoOl9c5ZrmkovD3/0/S/zkTWoPW6Hkxs9MlLVAoYNFbkjgNnht3/9jdO7t7T3fvqVC/3GXuXpac5WaNaP5b9heFrmLJzDorVD58I5GLzFLRnJu3JZ0vSWbWX6GQVZHQVSKS5ZKuDr/L8ExJH7t7eTxfIKPLhe5+2MxmSFopKUfSg+6+1cxmSypz9+WSHlDo0u0uhRrkJiRvxdkhyvNyl6S2khaH34fwtrtflrRFZ4kozw0SLMrzslLShWa2TVKlpB+5O1flAxblublB0kIz+/8UaoKfwj/mg2dmjyn0D4/O4X64EknNJcnd/0eh/riLJe2S9Kmk78R9DZxnAACA+Mv0ciEAAEBSELIAAAACQMgCAAAIACELAAAgAIQsAACAABCyAAAAAkDIAgAACMD/D4PsuNx7wFcrAAAAAElFTkSuQmCC\n"
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ],
      "source": [
        "# Plot the predictions (matplotlib needs them on the CPU, so move them off the GPU first)\n",
        "plot_predictions(predictions=y_preds.cpu())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s2OnlMWKjzX8"
      },
      "source": [
        "## 5. Save your trained model's `state_dict()` to file.\n",
        "  * Create a new instance of the model class you made in 2. and load the `state_dict()` you just saved into it.\n",
        "  * Perform predictions on your test data with the loaded model and confirm they match the original model predictions from 4."
      ]
    },
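    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The save/load round trip from this exercise can also be checked in isolation. The sketch below is a minimal example under stated assumptions: `TinyLinearModel` is a hypothetical stand-in for the `LinearRegressionModel` class from 2. (it uses `nn.Linear` instead of hand-made parameters), and the file is written to a temporary directory. Passing `map_location` to `torch.load()` lets a `state_dict()` that was saved from GPU tensors be loaded on a CPU-only machine.\n",
        "\n",
        "```python\n",
        "import tempfile\n",
        "from pathlib import Path\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "\n",
        "# Hypothetical stand-in for the model class from exercise 2\n",
        "class TinyLinearModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.linear_layer = nn.Linear(in_features=1, out_features=1)\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        return self.linear_layer(x)\n",
        "\n",
        "\n",
        "torch.manual_seed(42)\n",
        "original_model = TinyLinearModel()\n",
        "\n",
        "with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "    save_path = Path(tmp_dir) / \"tiny_linear_model.pth\"\n",
        "\n",
        "    # Save only the learned parameters (the state_dict), not the whole model object\n",
        "    torch.save(obj=original_model.state_dict(), f=save_path)\n",
        "\n",
        "    # Load the parameters into a fresh instance; map_location=\"cpu\" remaps any GPU tensors\n",
        "    reloaded_model = TinyLinearModel()\n",
        "    reloaded_model.load_state_dict(torch.load(f=save_path, map_location=\"cpu\"))\n",
        "\n",
        "# Every saved parameter should come back unchanged\n",
        "params_match = all(\n",
        "    torch.equal(p_orig, p_loaded)\n",
        "    for p_orig, p_loaded in zip(original_model.state_dict().values(),\n",
        "                                reloaded_model.state_dict().values())\n",
        ")\n",
        "print(params_match)\n",
        "```\n",
        "\n",
        "Because a `state_dict()` only stores parameter names and tensors, the model class has to be instantiated again before the saved values can be loaded into it."
      ]
    },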
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hgxhgD14qr-i",
        "outputId": "aa4e116e-f97b-4c28-e7bb-4b2738369fd7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Saving model to models/01_pytorch_model\n"
          ]
        }
      ],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "# 1. Create models directory\n",
        "MODEL_PATH = Path(\"models\")\n",
        "MODEL_PATH.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# 2. Create model save path\n",
        "MODEL_NAME = \"01_pytorch_model\"\n",
        "MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n",
        "\n",
        "# 3. Save the model state dict\n",
        "print(f\"Saving model to {MODEL_SAVE_PATH}\")\n",
        "torch.save(obj=model_1.state_dict(), f=MODEL_SAVE_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P9vTgiLRrJ7T",
        "outputId": "5901a52c-5800-4790-e2d0-e998c2aa2b82"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "LinearRegressionModel()"
            ]
          },
          "metadata": {},
          "execution_count": 16
        }
      ],
      "source": [
        "# Create a new instance of the model, load the saved state dict and put it on the target device\n",
        "loaded_model = LinearRegressionModel()\n",
        "loaded_model.load_state_dict(torch.load(f=MODEL_SAVE_PATH))\n",
        "loaded_model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8UGX3VebrVtI",
        "outputId": "eef16260-cf11-4895-b726-dedfa4e6b839"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True],\n",
              "        [True]], device='cuda:0')"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ],
      "source": [
        "# Make predictions with the loaded model and compare them to the previous predictions\n",
        "loaded_model.eval()\n",
        "with torch.inference_mode():\n",
        "  y_preds_new = loaded_model(X_test)\n",
        "y_preds == y_preds_new"
      ]
    },
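    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Element-wise `==` works here because the loaded model has exactly the same parameters and runs on the same device. For floating-point outputs it is often safer to use `torch.allclose()`, which tolerates tiny rounding differences and returns a single boolean instead of a tensor. A minimal sketch, assuming two hypothetical `nn.Linear` stand-ins for `model_1` and `loaded_model`, where one is loaded from the other's `state_dict()`:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "torch.manual_seed(42)\n",
        "\n",
        "# Hypothetical stand-ins for model_1 and loaded_model\n",
        "model_a = nn.Linear(in_features=1, out_features=1)\n",
        "model_b = nn.Linear(in_features=1, out_features=1)\n",
        "model_b.load_state_dict(model_a.state_dict())\n",
        "\n",
        "# Hypothetical test inputs with the same shape as X_test: [20, 1]\n",
        "X = torch.linspace(0.8, 0.99, steps=20).unsqueeze(dim=1)\n",
        "\n",
        "# Evaluation mode + inference_mode, as for the original predictions\n",
        "model_a.eval()\n",
        "model_b.eval()\n",
        "with torch.inference_mode():\n",
        "    preds_a = model_a(X)\n",
        "    preds_b = model_b(X)\n",
        "\n",
        "# A single True/False, with a small tolerance for floating-point rounding\n",
        "print(torch.allclose(preds_a, preds_b))\n",
        "# Exact element-wise comparison reduced to one value with .all()\n",
        "print((preds_a == preds_b).all())\n",
        "```"
      ]
    },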
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oQPeILR5_IDS",
        "outputId": "f784320a-f435-4502-d7cc-e66586c40919"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "OrderedDict([('weight', tensor([0.3067], device='cuda:0')),\n",
              "             ('bias', tensor([0.9011], device='cuda:0'))])"
            ]
          },
          "metadata": {},
          "execution_count": 18
        }
      ],
      "source": [
        "loaded_model.state_dict()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "include_colab_link": true
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.7 (tags/v3.10.7:6cc6b13, Sep  5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]"
    },
    "vscode": {
      "interpreter": {
        "hash": "2dcff298e6fc9f300c172f449ac5b974b753c4428b95c0ffa294019aef922779"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}